A sample implementation of a Siamese network using Keras: https://github.com/keras-team/keras/blob/master/examples/mnist_siamese.py


In [41]:
'''Trains a Siamese MLP on pairs of digits from the MNIST dataset.
It follows Hadsell-et-al.'06 [1] by computing the Euclidean distance on the
output of the shared network and by optimizing the contrastive loss (see paper
for more details).
# References
- Dimensionality Reduction by Learning an Invariant Mapping
    http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
Gets to 97.2% test accuracy after 20 epochs.
2 seconds per epoch on a Titan X Maxwell GPU
'''
from __future__ import absolute_import
from __future__ import print_function
import numpy as np

import random
from keras.datasets import mnist
from keras.models import Model
from keras.layers import Input, Flatten, Dense, Dropout, Lambda
from keras.optimizers import RMSprop
from keras import backend as K

num_classes = 10
epochs = 20


def euclidean_distance(vects):
    '''Euclidean distance between two batches of vectors.

    # Arguments
        vects: pair `(x, y)` of backend tensors; the squared difference
            is summed over axis 1.

    # Returns
        Tensor of distances with axis 1 kept as a singleton dimension.
    '''
    left, right = vects
    squared_diff = K.square(left - right)
    squared_norm = K.sum(squared_diff, axis=1, keepdims=True)
    # The squared norm is clamped from below by the backend epsilon before
    # the square root is taken (presumably to keep the gradient finite when
    # both vectors coincide).
    clamped = K.maximum(squared_norm, K.epsilon())
    return K.sqrt(clamped)


def eucl_dist_output_shape(shapes):
    '''Output shape of the distance layer: one value per sample.

    # Arguments
        shapes: pair with the shapes of the two inputs; only the leading
            entry of the first shape is used.

    # Returns
        Tuple `(leading_dim, 1)`.
    '''
    first_shape, _ = shapes
    leading_dim = first_shape[0]
    return (leading_dim, 1)


def contrastive_loss(y_true, y_pred):
    '''Contrastive loss from Hadsell-et-al.'06
    http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf

    # Arguments
        y_true: pair labels; 1 weights the squared-distance term and
            0 weights the margin term.
        y_pred: predicted distances.

    # Returns
        Mean loss over all elements.
    '''
    margin = 1
    # Term active where y_true == 1: the squared distance itself.
    square_pred = K.square(y_pred)
    # Term active where y_true == 0: squared shortfall below the margin,
    # which is zero once the distance reaches the margin.
    shortfall = K.maximum(margin - y_pred, 0)
    margin_square = K.square(shortfall)
    per_pair = y_true * square_pred + (1 - y_true) * margin_square
    return K.mean(per_pair)


def create_pairs(x, digit_indices):
    '''Positive and negative pair creation.
    Alternates between positive and negative pairs.

    For every class `d` and every position `i`, a positive pair made of
    samples `i` and `i + 1` of class `d` is emitted, followed by a negative
    pair made of sample `i` of class `d` and sample `i` of a randomly chosen
    different class.

    The number of classes is taken from `len(digit_indices)`, so the
    function does not depend on any module-level constant.

    # Arguments
        x: indexable collection of samples.
        digit_indices: sequence whose entry `d` lists the indices into `x`
            of the samples belonging to class `d`.

    # Returns
        Tuple `(pairs, labels)` of numpy arrays. Labels alternate 1
        (positive pair) and 0 (negative pair). Both arrays are empty when
        `digit_indices` is empty.
    '''
    pairs = []
    labels = []
    n_classes = len(digit_indices)
    if n_classes == 0:
        # Nothing to pair; avoid calling min() on an empty sequence.
        return np.array(pairs), np.array(labels)
    # Every class contributes the same number of pairs, bounded by the
    # smallest class; minus one because a positive pair uses i and i + 1.
    n = min([len(digit_indices[d]) for d in range(n_classes)]) - 1
    for d in range(n_classes):
        for i in range(n):
            anchor = x[digit_indices[d][i]]
            pairs.append([anchor, x[digit_indices[d][i + 1]]])
            # An offset in [1, n_classes - 1] always lands on another class.
            inc = random.randrange(1, n_classes)
            dn = (d + inc) % n_classes
            pairs.append([anchor, x[digit_indices[dn][i]]])
            labels.extend((1, 0))
    return np.array(pairs), np.array(labels)


def create_base_network(input_shape):
    '''Base network to be shared (eq. to feature extraction).

    The input is flattened, passed through two Dense(128, relu) + Dropout(0.1)
    stages and a final Dense(128, relu) layer.

    # Arguments
        input_shape: shape of a single sample, without the batch axis.

    # Returns
        A `Model` mapping the input to the output of the last Dense layer.
    '''
    inputs = Input(shape=input_shape)
    features = Flatten()(inputs)
    # Two identical hidden stages, each followed by light dropout.
    for _ in range(2):
        features = Dense(128, activation='relu')(features)
        features = Dropout(0.1)(features)
    embedding = Dense(128, activation='relu')(features)
    return Model(inputs, embedding)


def compute_accuracy(y_true, y_pred):
    '''Classification accuracy using a fixed 0.5 threshold on distances.

    # Arguments
        y_true: pair labels, compared element-wise with the thresholded
            predictions.
        y_pred: numpy array of predicted distances; it is flattened first.

    # Returns
        Fraction of pairs whose thresholded prediction equals the label.
    '''
    flat_distances = y_pred.ravel()
    # A distance strictly below 0.5 counts as a predicted positive pair.
    predicted_positive = flat_distances < 0.5
    matches = predicted_positive == y_true
    return np.mean(matches)


def accuracy(y_true, y_pred):
    '''Backend-tensor accuracy using a fixed 0.5 threshold on distances.

    # Arguments
        y_true: tensor of pair labels.
        y_pred: tensor of predicted distances.

    # Returns
        Mean of the element-wise equality between the labels and the
        thresholded predictions.
    '''
    below_threshold = y_pred < 0.5
    # Cast the boolean mask to the label dtype so both sides are comparable.
    predicted = K.cast(below_threshold, y_true.dtype)
    return K.mean(K.equal(y_true, predicted))


# the data, split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# convert to float32 and divide by 255 (presumably rescales 8-bit pixel
# values into [0, 1])
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
# shape of a single sample: everything after the leading sample axis
input_shape = x_train.shape[1:]

# create training+test positive and negative pairs
# digit_indices[i] holds the positions of every sample labelled i
digit_indices = [np.where(y_train == i)[0] for i in range(num_classes)]
tr_pairs, tr_y = create_pairs(x_train, digit_indices)

# same construction for the test split (digit_indices is rebound here)
digit_indices = [np.where(y_test == i)[0] for i in range(num_classes)]
te_pairs, te_y = create_pairs(x_test, digit_indices)

# network definition
base_network = create_base_network(input_shape)

base_network.summary()
# one input per branch of the Siamese network
input_a = Input(shape=input_shape)
input_b = Input(shape=input_shape)

# because we re-use the same instance `base_network`,
# the weights of the network
# will be shared across the two branches
processed_a = base_network(input_a)
processed_b = base_network(input_b)

# the model output is the Euclidean distance between the two embeddings
distance = Lambda(euclidean_distance,
                  output_shape=eucl_dist_output_shape)([processed_a, processed_b])

model = Model([input_a, input_b], distance)

# train
rms = RMSprop()
model.compile(loss=contrastive_loss, optimizer=rms, metrics=[accuracy])
# tr_pairs[:, 0] / tr_pairs[:, 1] select the first / second element of
# every pair and feed them to input_a / input_b respectively
model.fit([tr_pairs[:, 0], tr_pairs[:, 1]], tr_y,
          batch_size=128,
          epochs=epochs,
          validation_data=([te_pairs[:, 0], te_pairs[:, 1]], te_y))

# compute final accuracy on training and test sets
# (y_pred is reused: first for the training pairs, then for the test pairs)
y_pred = model.predict([tr_pairs[:, 0], tr_pairs[:, 1]])
tr_acc = compute_accuracy(tr_y, y_pred)
y_pred = model.predict([te_pairs[:, 0], te_pairs[:, 1]])
te_acc = compute_accuracy(te_y, y_pred)

print('* Accuracy on training set: %0.2f%%' % (100 * tr_acc))
print('* Accuracy on test set: %0.2f%%' % (100 * te_acc))


_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_72 (InputLayer)        (None, 28, 28)            0         
_________________________________________________________________
flatten_25 (Flatten)         (None, 784)               0         
_________________________________________________________________
dense_63 (Dense)             (None, 128)               100480    
_________________________________________________________________
dropout_39 (Dropout)         (None, 128)               0         
_________________________________________________________________
dense_64 (Dense)             (None, 128)               16512     
_________________________________________________________________
dropout_40 (Dropout)         (None, 128)               0         
_________________________________________________________________
dense_65 (Dense)             (None, 128)               16512     
=================================================================
Total params: 133,504
Trainable params: 133,504
Non-trainable params: 0
_________________________________________________________________
Train on 108400 samples, validate on 17820 samples
Epoch 1/20
108400/108400 [==============================] - 23s 212us/step - loss: 0.0968 - accuracy: 0.8853 - val_loss: 0.0530 - val_accuracy: 0.9448
Epoch 2/20
108400/108400 [==============================] - 16s 149us/step - loss: 0.0406 - accuracy: 0.9596 - val_loss: 0.0355 - val_accuracy: 0.9636
Epoch 3/20
108400/108400 [==============================] - 17s 154us/step - loss: 0.0280 - accuracy: 0.9725 - val_loss: 0.0308 - val_accuracy: 0.9659
Epoch 4/20
108400/108400 [==============================] - 16s 150us/step - loss: 0.0224 - accuracy: 0.9781 - val_loss: 0.0291 - val_accuracy: 0.9669
Epoch 5/20
108400/108400 [==============================] - 16s 147us/step - loss: 0.0194 - accuracy: 0.9806 - val_loss: 0.0287 - val_accuracy: 0.9681
Epoch 6/20
108400/108400 [==============================] - 17s 156us/step - loss: 0.0171 - accuracy: 0.9827 - val_loss: 0.0280 - val_accuracy: 0.9687
Epoch 7/20
108400/108400 [==============================] - 16s 152us/step - loss: 0.0157 - accuracy: 0.9846 - val_loss: 0.0282 - val_accuracy: 0.9682
Epoch 8/20
108400/108400 [==============================] - 15s 141us/step - loss: 0.0142 - accuracy: 0.9859 - val_loss: 0.0279 - val_accuracy: 0.9692
Epoch 9/20
108400/108400 [==============================] - 17s 154us/step - loss: 0.0132 - accuracy: 0.9865 - val_loss: 0.0293 - val_accuracy: 0.9678
Epoch 10/20
108400/108400 [==============================] - 17s 156us/step - loss: 0.0125 - accuracy: 0.9875 - val_loss: 0.0274 - val_accuracy: 0.9696
Epoch 11/20
108400/108400 [==============================] - 17s 158us/step - loss: 0.0116 - accuracy: 0.9882 - val_loss: 0.0282 - val_accuracy: 0.9679
Epoch 12/20
108400/108400 [==============================] - 18s 162us/step - loss: 0.0110 - accuracy: 0.9891 - val_loss: 0.0283 - val_accuracy: 0.9674
Epoch 13/20
108400/108400 [==============================] - 15s 142us/step - loss: 0.0108 - accuracy: 0.9889 - val_loss: 0.0281 - val_accuracy: 0.9684
Epoch 14/20
108400/108400 [==============================] - 17s 161us/step - loss: 0.0104 - accuracy: 0.9895 - val_loss: 0.0286 - val_accuracy: 0.9668
Epoch 15/20
108400/108400 [==============================] - 19s 175us/step - loss: 0.0098 - accuracy: 0.9899 - val_loss: 0.0285 - val_accuracy: 0.9685
Epoch 16/20
108400/108400 [==============================] - 18s 164us/step - loss: 0.0097 - accuracy: 0.9903 - val_loss: 0.0278 - val_accuracy: 0.9686
Epoch 17/20
108400/108400 [==============================] - 18s 163us/step - loss: 0.0093 - accuracy: 0.9905 - val_loss: 0.0275 - val_accuracy: 0.9690
Epoch 18/20
108400/108400 [==============================] - 18s 166us/step - loss: 0.0090 - accuracy: 0.9909 - val_loss: 0.0279 - val_accuracy: 0.9685
Epoch 19/20
108400/108400 [==============================] - 18s 164us/step - loss: 0.0090 - accuracy: 0.9908 - val_loss: 0.0298 - val_accuracy: 0.9674
Epoch 20/20
108400/108400 [==============================] - 17s 161us/step - loss: 0.0087 - accuracy: 0.9912 - val_loss: 0.0294 - val_accuracy: 0.9679
* Accuracy on training set: 99.11%
* Accuracy on test set: 96.69%

A little modification demonstrates how to apply your own model. Be careful with your input_shape, especially the channel dimension.

I used the ResNet50 implementation based on Andrew Ng's course on Coursera.

Even though it is a huge network that is ill-suited to such a small input size (28*28), it is only used here to show how to do this. See below:


In [43]:
from __future__ import absolute_import
from __future__ import print_function
import numpy as np

import random
from keras.datasets import mnist
from keras.models import Model
from keras.layers import Input, Flatten, Dense, Dropout, Lambda, Conv2D
from keras.optimizers import RMSprop
from keras import backend as K
from Siam_model import *

num_classes = 10
epochs = 1


def euclidean_distance(vects):
    '''Euclidean distance between the two tensors in `vects`.

    The squared difference is summed over axis 1 (kept as a singleton
    dimension) and the square root of the result is returned.
    '''
    first, second = vects
    delta = first - second
    sum_of_squares = K.sum(K.square(delta), axis=1, keepdims=True)
    # Lower-bound the sum by the backend epsilon before the square root.
    floor = K.epsilon()
    return K.sqrt(K.maximum(sum_of_squares, floor))


def eucl_dist_output_shape(shapes):
    '''Shape produced by the distance layer.

    Takes the pair of input shapes and returns `(shapes[0][0], 1)`: the
    leading dimension of the first input and a single distance column.
    '''
    shape_a, _shape_b = shapes
    return tuple([shape_a[0], 1])


def contrastive_loss(y_true, y_pred):
    '''Contrastive loss with a margin of 1.

    Where `y_true` is 1 the squared predicted distance is penalised; where
    it is 0 the squared amount by which the distance falls short of the
    margin is penalised. The mean over all elements is returned.
    '''
    margin = 1
    square_pred = K.square(y_pred)
    # Clipped at zero so distances beyond the margin add no loss.
    margin_gap = K.maximum(margin - y_pred, 0)
    margin_square = K.square(margin_gap)
    weighted = y_true * square_pred + (1 - y_true) * margin_square
    return K.mean(weighted)


def create_pairs(x, digit_indices):
    '''Positive and negative pair creation.
    Alternates between positive and negative pairs.

    Each class `d` yields, for every position `i`, a positive pair
    (samples `i` and `i + 1` of class `d`) immediately followed by a
    negative pair (sample `i` of class `d` with sample `i` of another,
    randomly picked class).

    The class count is derived from `len(digit_indices)` instead of a
    module-level constant, which keeps the function self-contained.

    # Arguments
        x: indexable collection of samples.
        digit_indices: sequence whose entry `d` lists the indices into `x`
            of the samples of class `d`.

    # Returns
        Tuple `(pairs, labels)` of numpy arrays, with label 1 for each
        positive pair and 0 for each negative pair. Both are empty when
        `digit_indices` is empty.
    '''
    class_count = len(digit_indices)
    pairs = []
    labels = []
    if class_count == 0:
        # No classes means no pairs; min() below would fail on empty input.
        return np.array(pairs), np.array(labels)
    # The smallest class limits how many pairs each class can provide;
    # one is subtracted because positive pairs need index i + 1.
    n = min([len(digit_indices[d]) for d in range(class_count)]) - 1
    for d in range(class_count):
        own = digit_indices[d]
        for i in range(n):
            pairs.append([x[own[i]], x[own[i + 1]]])
            # Shifting by 1..class_count - 1 can never return to class d.
            other = (d + random.randrange(1, class_count)) % class_count
            pairs.append([x[own[i]], x[digit_indices[other][i]]])
            labels.append(1)
            labels.append(0)
    return np.array(pairs), np.array(labels)


def create_base_network(input_shape):
    '''Base network to be shared (eq. to feature extraction).

    A single-filter 3x3 convolution (stride 1, 'valid' padding) is applied
    first; its output is flattened and passed through two
    Dense(128, relu) + Dropout(0.1) stages and a final Dense(128, relu).

    # Arguments
        input_shape: shape of a single sample, without the batch axis.

    # Returns
        A `Model` mapping the input to the output of the last Dense layer.
    '''
    inputs = Input(shape=input_shape)
    conv_out = Conv2D(1, (3, 3), strides=(1, 1), padding='valid')(inputs)
    hidden = Flatten()(conv_out)
    # Two identical hidden stages, each followed by light dropout.
    for _ in range(2):
        hidden = Dense(128, activation='relu')(hidden)
        hidden = Dropout(0.1)(hidden)
    outputs = Dense(128, activation='relu')(hidden)
    return Model(inputs, outputs)


def compute_accuracy(y_true, y_pred):
    '''Accuracy of distance predictions against pair labels.

    `y_pred` is flattened and every distance strictly below 0.5 is treated
    as a predicted positive; the returned value is the fraction of entries
    where that prediction equals `y_true`.
    '''
    close_enough = y_pred.ravel() < 0.5
    agreement = (close_enough == y_true)
    return np.mean(agreement)


def accuracy(y_true, y_pred):
    '''Accuracy metric on backend tensors with a fixed 0.5 threshold.

    Distances below 0.5 are cast to the dtype of `y_true` and compared
    element-wise with the labels; the mean of that comparison is returned.
    '''
    thresholded = K.cast(y_pred < 0.5, y_true.dtype)
    is_correct = K.equal(y_true, thresholded)
    return K.mean(is_correct)


# the data, split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# convert to float32 and divide by 255 (presumably rescales 8-bit pixel
# values into [0, 1])
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
input_shape = x_train.shape[1:] # (28, 28)

# append a channel axis to input_shape, (28, 28) => (28, 28, 1), as required
# by the ResNet50 model used below
# (tuple concatenation)
chanel = (1,)
input_shape = input_shape + chanel

# create training+test positive and negative pairs
# digit_indices[i] holds the positions of every sample labelled i
digit_indices = [np.where(y_train == i)[0] for i in range(num_classes)]
tr_pairs, tr_y = create_pairs(x_train, digit_indices)

# expand dims: data goes from (samples, 2, 28, 28) => (samples, 2, 28, 28, 1)
tr_pairs = np.expand_dims(tr_pairs, axis=4)

#print(tr_pairs.shape)
digit_indices = [np.where(y_test == i)[0] for i in range(num_classes)]
te_pairs, te_y = create_pairs(x_test, digit_indices)

# expand dims: (samples, 2, 28, 28) => (samples, 2, 28, 28, 1), same as tr_pairs
te_pairs = np.expand_dims(te_pairs, axis=4)

# network definition
# (the small MLP/conv base network is disabled in favour of ResNet50)
#base_network = create_base_network(input_shape)
#base_network.summary()

# ResNet50 is not defined in this cell; presumably it comes from the
# `Siam_model` star import, and `classes = 128` presumably sets the width
# of its output -- TODO confirm against Siam_model
base_network = ResNet50(input_shape = input_shape, classes = 128)
base_network.summary()
input_a = Input(shape=input_shape) # Tensor("input_27:0", shape=(?, 28, 28, 1), dtype=float32)
input_b = Input(shape=input_shape)
#print(input_a)

# the same `base_network` instance is applied to both inputs, so both
# branches share its weights
processed_a = base_network(input_a)
processed_b = base_network(input_b)

# the model output is the Euclidean distance between the two embeddings
distance = Lambda(euclidean_distance,
                  output_shape=eucl_dist_output_shape)([processed_a, processed_b])

model = Model([input_a, input_b], distance)

# train
rms = RMSprop()
model.compile(loss=contrastive_loss, optimizer=rms, metrics=[accuracy])
# tr_pairs[:, 0] / tr_pairs[:, 1] select the first / second element of
# every pair and feed them to input_a / input_b respectively
model.fit([tr_pairs[:, 0], tr_pairs[:, 1]], tr_y,
          batch_size=128,
          epochs=epochs,
          validation_data=([te_pairs[:, 0], te_pairs[:, 1]], te_y))

# compute final accuracy on training and test sets
# (y_pred is reused: first for the training pairs, then for the test pairs)
y_pred = model.predict([tr_pairs[:, 0], tr_pairs[:, 1]])
tr_acc = compute_accuracy(tr_y, y_pred)
y_pred = model.predict([te_pairs[:, 0], te_pairs[:, 1]])
te_acc = compute_accuracy(te_y, y_pred)

print('* Accuracy on training set: %0.2f%%' % (100 * tr_acc))
print('* Accuracy on test set: %0.2f%%' % (100 * te_acc))


__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_78 (InputLayer)           (None, 28, 28, 1)    0                                            
__________________________________________________________________________________________________
zero_padding2d_7 (ZeroPadding2D (None, 34, 34, 1)    0           input_78[0][0]                   
__________________________________________________________________________________________________
conv1 (Conv2D)                  (None, 14, 14, 64)   3200        zero_padding2d_7[0][0]           
__________________________________________________________________________________________________
bn_conv1 (BatchNormalization)   (None, 14, 14, 64)   256         conv1[0][0]                      
__________________________________________________________________________________________________
activation_295 (Activation)     (None, 14, 14, 64)   0           bn_conv1[0][0]                   
__________________________________________________________________________________________________
max_pooling2d_7 (MaxPooling2D)  (None, 6, 6, 64)     0           activation_295[0][0]             
__________________________________________________________________________________________________
res2a_branch2a (Conv2D)         (None, 6, 6, 64)     4160        max_pooling2d_7[0][0]            
__________________________________________________________________________________________________
bn2a_branch2a (BatchNormalizati (None, 6, 6, 64)     256         res2a_branch2a[0][0]             
__________________________________________________________________________________________________
activation_296 (Activation)     (None, 6, 6, 64)     0           bn2a_branch2a[0][0]              
__________________________________________________________________________________________________
res2a_branch2b (Conv2D)         (None, 6, 6, 64)     36928       activation_296[0][0]             
__________________________________________________________________________________________________
bn2a_branch2b (BatchNormalizati (None, 6, 6, 64)     256         res2a_branch2b[0][0]             
__________________________________________________________________________________________________
activation_297 (Activation)     (None, 6, 6, 64)     0           bn2a_branch2b[0][0]              
__________________________________________________________________________________________________
res2a_branch2c (Conv2D)         (None, 6, 6, 256)    16640       activation_297[0][0]             
__________________________________________________________________________________________________
res2a_branch1 (Conv2D)          (None, 6, 6, 256)    16640       max_pooling2d_7[0][0]            
__________________________________________________________________________________________________
bn2a_branch2c (BatchNormalizati (None, 6, 6, 256)    1024        res2a_branch2c[0][0]             
__________________________________________________________________________________________________
bn2a_branch1 (BatchNormalizatio (None, 6, 6, 256)    1024        res2a_branch1[0][0]              
__________________________________________________________________________________________________
add_97 (Add)                    (None, 6, 6, 256)    0           bn2a_branch2c[0][0]              
                                                                 bn2a_branch1[0][0]               
__________________________________________________________________________________________________
activation_298 (Activation)     (None, 6, 6, 256)    0           add_97[0][0]                     
__________________________________________________________________________________________________
res2b_branch2a (Conv2D)         (None, 6, 6, 64)     16448       activation_298[0][0]             
__________________________________________________________________________________________________
bn2b_branch2a (BatchNormalizati (None, 6, 6, 64)     256         res2b_branch2a[0][0]             
__________________________________________________________________________________________________
activation_299 (Activation)     (None, 6, 6, 64)     0           bn2b_branch2a[0][0]              
__________________________________________________________________________________________________
res2b_branch2b (Conv2D)         (None, 6, 6, 64)     36928       activation_299[0][0]             
__________________________________________________________________________________________________
bn2b_branch2b (BatchNormalizati (None, 6, 6, 64)     256         res2b_branch2b[0][0]             
__________________________________________________________________________________________________
activation_300 (Activation)     (None, 6, 6, 64)     0           bn2b_branch2b[0][0]              
__________________________________________________________________________________________________
res2b_branch2c (Conv2D)         (None, 6, 6, 256)    16640       activation_300[0][0]             
__________________________________________________________________________________________________
bn2b_branch2c (BatchNormalizati (None, 6, 6, 256)    1024        res2b_branch2c[0][0]             
__________________________________________________________________________________________________
add_98 (Add)                    (None, 6, 6, 256)    0           bn2b_branch2c[0][0]              
                                                                 activation_298[0][0]             
__________________________________________________________________________________________________
activation_301 (Activation)     (None, 6, 6, 256)    0           add_98[0][0]                     
__________________________________________________________________________________________________
res2c_branch2a (Conv2D)         (None, 6, 6, 64)     16448       activation_301[0][0]             
__________________________________________________________________________________________________
bn2c_branch2a (BatchNormalizati (None, 6, 6, 64)     256         res2c_branch2a[0][0]             
__________________________________________________________________________________________________
activation_302 (Activation)     (None, 6, 6, 64)     0           bn2c_branch2a[0][0]              
__________________________________________________________________________________________________
res2c_branch2b (Conv2D)         (None, 6, 6, 64)     36928       activation_302[0][0]             
__________________________________________________________________________________________________
bn2c_branch2b (BatchNormalizati (None, 6, 6, 64)     256         res2c_branch2b[0][0]             
__________________________________________________________________________________________________
activation_303 (Activation)     (None, 6, 6, 64)     0           bn2c_branch2b[0][0]              
__________________________________________________________________________________________________
res2c_branch2c (Conv2D)         (None, 6, 6, 256)    16640       activation_303[0][0]             
__________________________________________________________________________________________________
bn2c_branch2c (BatchNormalizati (None, 6, 6, 256)    1024        res2c_branch2c[0][0]             
__________________________________________________________________________________________________
add_99 (Add)                    (None, 6, 6, 256)    0           bn2c_branch2c[0][0]              
                                                                 activation_301[0][0]             
__________________________________________________________________________________________________
activation_304 (Activation)     (None, 6, 6, 256)    0           add_99[0][0]                     
__________________________________________________________________________________________________
res3a_branch2a (Conv2D)         (None, 3, 3, 128)    32896       activation_304[0][0]             
__________________________________________________________________________________________________
bn3a_branch2a (BatchNormalizati (None, 3, 3, 128)    512         res3a_branch2a[0][0]             
__________________________________________________________________________________________________
activation_305 (Activation)     (None, 3, 3, 128)    0           bn3a_branch2a[0][0]              
__________________________________________________________________________________________________
res3a_branch2b (Conv2D)         (None, 3, 3, 128)    147584      activation_305[0][0]             
__________________________________________________________________________________________________
bn3a_branch2b (BatchNormalizati (None, 3, 3, 128)    512         res3a_branch2b[0][0]             
__________________________________________________________________________________________________
activation_306 (Activation)     (None, 3, 3, 128)    0           bn3a_branch2b[0][0]              
__________________________________________________________________________________________________
res3a_branch2c (Conv2D)         (None, 3, 3, 512)    66048       activation_306[0][0]             
__________________________________________________________________________________________________
res3a_branch1 (Conv2D)          (None, 3, 3, 512)    131584      activation_304[0][0]             
__________________________________________________________________________________________________
bn3a_branch2c (BatchNormalizati (None, 3, 3, 512)    2048        res3a_branch2c[0][0]             
__________________________________________________________________________________________________
bn3a_branch1 (BatchNormalizatio (None, 3, 3, 512)    2048        res3a_branch1[0][0]              
__________________________________________________________________________________________________
add_100 (Add)                   (None, 3, 3, 512)    0           bn3a_branch2c[0][0]              
                                                                 bn3a_branch1[0][0]               
__________________________________________________________________________________________________
activation_307 (Activation)     (None, 3, 3, 512)    0           add_100[0][0]                    
__________________________________________________________________________________________________
res3b_branch2a (Conv2D)         (None, 3, 3, 128)    65664       activation_307[0][0]             
__________________________________________________________________________________________________
bn3b_branch2a (BatchNormalizati (None, 3, 3, 128)    512         res3b_branch2a[0][0]             
__________________________________________________________________________________________________
activation_308 (Activation)     (None, 3, 3, 128)    0           bn3b_branch2a[0][0]              
__________________________________________________________________________________________________
res3b_branch2b (Conv2D)         (None, 3, 3, 128)    147584      activation_308[0][0]             
__________________________________________________________________________________________________
bn3b_branch2b (BatchNormalizati (None, 3, 3, 128)    512         res3b_branch2b[0][0]             
__________________________________________________________________________________________________
activation_309 (Activation)     (None, 3, 3, 128)    0           bn3b_branch2b[0][0]              
__________________________________________________________________________________________________
res3b_branch2c (Conv2D)         (None, 3, 3, 512)    66048       activation_309[0][0]             
__________________________________________________________________________________________________
bn3b_branch2c (BatchNormalizati (None, 3, 3, 512)    2048        res3b_branch2c[0][0]             
__________________________________________________________________________________________________
add_101 (Add)                   (None, 3, 3, 512)    0           bn3b_branch2c[0][0]              
                                                                 activation_307[0][0]             
__________________________________________________________________________________________________
activation_310 (Activation)     (None, 3, 3, 512)    0           add_101[0][0]                    
__________________________________________________________________________________________________
res3c_branch2a (Conv2D)         (None, 3, 3, 128)    65664       activation_310[0][0]             
__________________________________________________________________________________________________
bn3c_branch2a (BatchNormalizati (None, 3, 3, 128)    512         res3c_branch2a[0][0]             
__________________________________________________________________________________________________
activation_311 (Activation)     (None, 3, 3, 128)    0           bn3c_branch2a[0][0]              
__________________________________________________________________________________________________
res3c_branch2b (Conv2D)         (None, 3, 3, 128)    147584      activation_311[0][0]             
__________________________________________________________________________________________________
bn3c_branch2b (BatchNormalizati (None, 3, 3, 128)    512         res3c_branch2b[0][0]             
__________________________________________________________________________________________________
activation_312 (Activation)     (None, 3, 3, 128)    0           bn3c_branch2b[0][0]              
__________________________________________________________________________________________________
res3c_branch2c (Conv2D)         (None, 3, 3, 512)    66048       activation_312[0][0]             
__________________________________________________________________________________________________
bn3c_branch2c (BatchNormalizati (None, 3, 3, 512)    2048        res3c_branch2c[0][0]             
__________________________________________________________________________________________________
add_102 (Add)                   (None, 3, 3, 512)    0           bn3c_branch2c[0][0]              
                                                                 activation_310[0][0]             
__________________________________________________________________________________________________
activation_313 (Activation)     (None, 3, 3, 512)    0           add_102[0][0]                    
__________________________________________________________________________________________________
res3d_branch2a (Conv2D)         (None, 3, 3, 128)    65664       activation_313[0][0]             
__________________________________________________________________________________________________
bn3d_branch2a (BatchNormalizati (None, 3, 3, 128)    512         res3d_branch2a[0][0]             
__________________________________________________________________________________________________
activation_314 (Activation)     (None, 3, 3, 128)    0           bn3d_branch2a[0][0]              
__________________________________________________________________________________________________
res3d_branch2b (Conv2D)         (None, 3, 3, 128)    147584      activation_314[0][0]             
__________________________________________________________________________________________________
bn3d_branch2b (BatchNormalizati (None, 3, 3, 128)    512         res3d_branch2b[0][0]             
__________________________________________________________________________________________________
activation_315 (Activation)     (None, 3, 3, 128)    0           bn3d_branch2b[0][0]              
__________________________________________________________________________________________________
res3d_branch2c (Conv2D)         (None, 3, 3, 512)    66048       activation_315[0][0]             
__________________________________________________________________________________________________
bn3d_branch2c (BatchNormalizati (None, 3, 3, 512)    2048        res3d_branch2c[0][0]             
__________________________________________________________________________________________________
add_103 (Add)                   (None, 3, 3, 512)    0           bn3d_branch2c[0][0]              
                                                                 activation_313[0][0]             
__________________________________________________________________________________________________
activation_316 (Activation)     (None, 3, 3, 512)    0           add_103[0][0]                    
__________________________________________________________________________________________________
res4a_branch2a (Conv2D)         (None, 2, 2, 256)    131328      activation_316[0][0]             
__________________________________________________________________________________________________
bn4a_branch2a (BatchNormalizati (None, 2, 2, 256)    1024        res4a_branch2a[0][0]             
__________________________________________________________________________________________________
activation_317 (Activation)     (None, 2, 2, 256)    0           bn4a_branch2a[0][0]              
__________________________________________________________________________________________________
res4a_branch2b (Conv2D)         (None, 2, 2, 256)    590080      activation_317[0][0]             
__________________________________________________________________________________________________
bn4a_branch2b (BatchNormalizati (None, 2, 2, 256)    1024        res4a_branch2b[0][0]             
__________________________________________________________________________________________________
activation_318 (Activation)     (None, 2, 2, 256)    0           bn4a_branch2b[0][0]              
__________________________________________________________________________________________________
res4a_branch2c (Conv2D)         (None, 2, 2, 1024)   263168      activation_318[0][0]             
__________________________________________________________________________________________________
res4a_branch1 (Conv2D)          (None, 2, 2, 1024)   525312      activation_316[0][0]             
__________________________________________________________________________________________________
bn4a_branch2c (BatchNormalizati (None, 2, 2, 1024)   4096        res4a_branch2c[0][0]             
__________________________________________________________________________________________________
bn4a_branch1 (BatchNormalizatio (None, 2, 2, 1024)   4096        res4a_branch1[0][0]              
__________________________________________________________________________________________________
add_104 (Add)                   (None, 2, 2, 1024)   0           bn4a_branch2c[0][0]              
                                                                 bn4a_branch1[0][0]               
__________________________________________________________________________________________________
activation_319 (Activation)     (None, 2, 2, 1024)   0           add_104[0][0]                    
__________________________________________________________________________________________________
res4b_branch2a (Conv2D)         (None, 2, 2, 256)    262400      activation_319[0][0]             
__________________________________________________________________________________________________
bn4b_branch2a (BatchNormalizati (None, 2, 2, 256)    1024        res4b_branch2a[0][0]             
__________________________________________________________________________________________________
activation_320 (Activation)     (None, 2, 2, 256)    0           bn4b_branch2a[0][0]              
__________________________________________________________________________________________________
res4b_branch2b (Conv2D)         (None, 2, 2, 256)    590080      activation_320[0][0]             
__________________________________________________________________________________________________
bn4b_branch2b (BatchNormalizati (None, 2, 2, 256)    1024        res4b_branch2b[0][0]             
__________________________________________________________________________________________________
activation_321 (Activation)     (None, 2, 2, 256)    0           bn4b_branch2b[0][0]              
__________________________________________________________________________________________________
res4b_branch2c (Conv2D)         (None, 2, 2, 1024)   263168      activation_321[0][0]             
__________________________________________________________________________________________________
bn4b_branch2c (BatchNormalizati (None, 2, 2, 1024)   4096        res4b_branch2c[0][0]             
__________________________________________________________________________________________________
add_105 (Add)                   (None, 2, 2, 1024)   0           bn4b_branch2c[0][0]              
                                                                 activation_319[0][0]             
__________________________________________________________________________________________________
activation_322 (Activation)     (None, 2, 2, 1024)   0           add_105[0][0]                    
__________________________________________________________________________________________________
res4c_branch2a (Conv2D)         (None, 2, 2, 256)    262400      activation_322[0][0]             
__________________________________________________________________________________________________
bn4c_branch2a (BatchNormalizati (None, 2, 2, 256)    1024        res4c_branch2a[0][0]             
__________________________________________________________________________________________________
activation_323 (Activation)     (None, 2, 2, 256)    0           bn4c_branch2a[0][0]              
__________________________________________________________________________________________________
res4c_branch2b (Conv2D)         (None, 2, 2, 256)    590080      activation_323[0][0]             
__________________________________________________________________________________________________
bn4c_branch2b (BatchNormalizati (None, 2, 2, 256)    1024        res4c_branch2b[0][0]             
__________________________________________________________________________________________________
activation_324 (Activation)     (None, 2, 2, 256)    0           bn4c_branch2b[0][0]              
__________________________________________________________________________________________________
res4c_branch2c (Conv2D)         (None, 2, 2, 1024)   263168      activation_324[0][0]             
__________________________________________________________________________________________________
bn4c_branch2c (BatchNormalizati (None, 2, 2, 1024)   4096        res4c_branch2c[0][0]             
__________________________________________________________________________________________________
add_106 (Add)                   (None, 2, 2, 1024)   0           bn4c_branch2c[0][0]              
                                                                 activation_322[0][0]             
__________________________________________________________________________________________________
activation_325 (Activation)     (None, 2, 2, 1024)   0           add_106[0][0]                    
__________________________________________________________________________________________________
res4d_branch2a (Conv2D)         (None, 2, 2, 256)    262400      activation_325[0][0]             
__________________________________________________________________________________________________
bn4d_branch2a (BatchNormalizati (None, 2, 2, 256)    1024        res4d_branch2a[0][0]             
__________________________________________________________________________________________________
activation_326 (Activation)     (None, 2, 2, 256)    0           bn4d_branch2a[0][0]              
__________________________________________________________________________________________________
res4d_branch2b (Conv2D)         (None, 2, 2, 256)    590080      activation_326[0][0]             
__________________________________________________________________________________________________
bn4d_branch2b (BatchNormalizati (None, 2, 2, 256)    1024        res4d_branch2b[0][0]             
__________________________________________________________________________________________________
activation_327 (Activation)     (None, 2, 2, 256)    0           bn4d_branch2b[0][0]              
__________________________________________________________________________________________________
res4d_branch2c (Conv2D)         (None, 2, 2, 1024)   263168      activation_327[0][0]             
__________________________________________________________________________________________________
bn4d_branch2c (BatchNormalizati (None, 2, 2, 1024)   4096        res4d_branch2c[0][0]             
__________________________________________________________________________________________________
add_107 (Add)                   (None, 2, 2, 1024)   0           bn4d_branch2c[0][0]              
                                                                 activation_325[0][0]             
__________________________________________________________________________________________________
activation_328 (Activation)     (None, 2, 2, 1024)   0           add_107[0][0]                    
__________________________________________________________________________________________________
res4e_branch2a (Conv2D)         (None, 2, 2, 256)    262400      activation_328[0][0]             
__________________________________________________________________________________________________
bn4e_branch2a (BatchNormalizati (None, 2, 2, 256)    1024        res4e_branch2a[0][0]             
__________________________________________________________________________________________________
activation_329 (Activation)     (None, 2, 2, 256)    0           bn4e_branch2a[0][0]              
__________________________________________________________________________________________________
res4e_branch2b (Conv2D)         (None, 2, 2, 256)    590080      activation_329[0][0]             
__________________________________________________________________________________________________
bn4e_branch2b (BatchNormalizati (None, 2, 2, 256)    1024        res4e_branch2b[0][0]             
__________________________________________________________________________________________________
activation_330 (Activation)     (None, 2, 2, 256)    0           bn4e_branch2b[0][0]              
__________________________________________________________________________________________________
res4e_branch2c (Conv2D)         (None, 2, 2, 1024)   263168      activation_330[0][0]             
__________________________________________________________________________________________________
bn4e_branch2c (BatchNormalizati (None, 2, 2, 1024)   4096        res4e_branch2c[0][0]             
__________________________________________________________________________________________________
add_108 (Add)                   (None, 2, 2, 1024)   0           bn4e_branch2c[0][0]              
                                                                 activation_328[0][0]             
__________________________________________________________________________________________________
activation_331 (Activation)     (None, 2, 2, 1024)   0           add_108[0][0]                    
__________________________________________________________________________________________________
res4f_branch2a (Conv2D)         (None, 2, 2, 256)    262400      activation_331[0][0]             
__________________________________________________________________________________________________
bn4f_branch2a (BatchNormalizati (None, 2, 2, 256)    1024        res4f_branch2a[0][0]             
__________________________________________________________________________________________________
activation_332 (Activation)     (None, 2, 2, 256)    0           bn4f_branch2a[0][0]              
__________________________________________________________________________________________________
res4f_branch2b (Conv2D)         (None, 2, 2, 256)    590080      activation_332[0][0]             
__________________________________________________________________________________________________
bn4f_branch2b (BatchNormalizati (None, 2, 2, 256)    1024        res4f_branch2b[0][0]             
__________________________________________________________________________________________________
activation_333 (Activation)     (None, 2, 2, 256)    0           bn4f_branch2b[0][0]              
__________________________________________________________________________________________________
res4f_branch2c (Conv2D)         (None, 2, 2, 1024)   263168      activation_333[0][0]             
__________________________________________________________________________________________________
bn4f_branch2c (BatchNormalizati (None, 2, 2, 1024)   4096        res4f_branch2c[0][0]             
__________________________________________________________________________________________________
add_109 (Add)                   (None, 2, 2, 1024)   0           bn4f_branch2c[0][0]              
                                                                 activation_331[0][0]             
__________________________________________________________________________________________________
activation_334 (Activation)     (None, 2, 2, 1024)   0           add_109[0][0]                    
__________________________________________________________________________________________________
res5a_branch2a (Conv2D)         (None, 1, 1, 512)    524800      activation_334[0][0]             
__________________________________________________________________________________________________
bn5a_branch2a (BatchNormalizati (None, 1, 1, 512)    2048        res5a_branch2a[0][0]             
__________________________________________________________________________________________________
activation_335 (Activation)     (None, 1, 1, 512)    0           bn5a_branch2a[0][0]              
__________________________________________________________________________________________________
res5a_branch2b (Conv2D)         (None, 1, 1, 512)    2359808     activation_335[0][0]             
__________________________________________________________________________________________________
bn5a_branch2b (BatchNormalizati (None, 1, 1, 512)    2048        res5a_branch2b[0][0]             
__________________________________________________________________________________________________
activation_336 (Activation)     (None, 1, 1, 512)    0           bn5a_branch2b[0][0]              
__________________________________________________________________________________________________
res5a_branch2c (Conv2D)         (None, 1, 1, 2048)   1050624     activation_336[0][0]             
__________________________________________________________________________________________________
res5a_branch1 (Conv2D)          (None, 1, 1, 2048)   2099200     activation_334[0][0]             
__________________________________________________________________________________________________
bn5a_branch2c (BatchNormalizati (None, 1, 1, 2048)   8192        res5a_branch2c[0][0]             
__________________________________________________________________________________________________
bn5a_branch1 (BatchNormalizatio (None, 1, 1, 2048)   8192        res5a_branch1[0][0]              
__________________________________________________________________________________________________
add_110 (Add)                   (None, 1, 1, 2048)   0           bn5a_branch2c[0][0]              
                                                                 bn5a_branch1[0][0]               
__________________________________________________________________________________________________
activation_337 (Activation)     (None, 1, 1, 2048)   0           add_110[0][0]                    
__________________________________________________________________________________________________
res5b_branch2a (Conv2D)         (None, 1, 1, 512)    1049088     activation_337[0][0]             
__________________________________________________________________________________________________
bn5b_branch2a (BatchNormalizati (None, 1, 1, 512)    2048        res5b_branch2a[0][0]             
__________________________________________________________________________________________________
activation_338 (Activation)     (None, 1, 1, 512)    0           bn5b_branch2a[0][0]              
__________________________________________________________________________________________________
res5b_branch2b (Conv2D)         (None, 1, 1, 512)    2359808     activation_338[0][0]             
__________________________________________________________________________________________________
bn5b_branch2b (BatchNormalizati (None, 1, 1, 512)    2048        res5b_branch2b[0][0]             
__________________________________________________________________________________________________
activation_339 (Activation)     (None, 1, 1, 512)    0           bn5b_branch2b[0][0]              
__________________________________________________________________________________________________
res5b_branch2c (Conv2D)         (None, 1, 1, 2048)   1050624     activation_339[0][0]             
__________________________________________________________________________________________________
bn5b_branch2c (BatchNormalizati (None, 1, 1, 2048)   8192        res5b_branch2c[0][0]             
__________________________________________________________________________________________________
add_111 (Add)                   (None, 1, 1, 2048)   0           bn5b_branch2c[0][0]              
                                                                 activation_337[0][0]             
__________________________________________________________________________________________________
activation_340 (Activation)     (None, 1, 1, 2048)   0           add_111[0][0]                    
__________________________________________________________________________________________________
res5c_branch2a (Conv2D)         (None, 1, 1, 512)    1049088     activation_340[0][0]             
__________________________________________________________________________________________________
bn5c_branch2a (BatchNormalizati (None, 1, 1, 512)    2048        res5c_branch2a[0][0]             
__________________________________________________________________________________________________
activation_341 (Activation)     (None, 1, 1, 512)    0           bn5c_branch2a[0][0]              
__________________________________________________________________________________________________
res5c_branch2b (Conv2D)         (None, 1, 1, 512)    2359808     activation_341[0][0]             
__________________________________________________________________________________________________
bn5c_branch2b (BatchNormalizati (None, 1, 1, 512)    2048        res5c_branch2b[0][0]             
__________________________________________________________________________________________________
activation_342 (Activation)     (None, 1, 1, 512)    0           bn5c_branch2b[0][0]              
__________________________________________________________________________________________________
res5c_branch2c (Conv2D)         (None, 1, 1, 2048)   1050624     activation_342[0][0]             
__________________________________________________________________________________________________
bn5c_branch2c (BatchNormalizati (None, 1, 1, 2048)   8192        res5c_branch2c[0][0]             
__________________________________________________________________________________________________
add_112 (Add)                   (None, 1, 1, 2048)   0           bn5c_branch2c[0][0]              
                                                                 activation_340[0][0]             
__________________________________________________________________________________________________
activation_343 (Activation)     (None, 1, 1, 2048)   0           add_112[0][0]                    
__________________________________________________________________________________________________
flatten_27 (Flatten)            (None, 2048)         0           activation_343[0][0]             
__________________________________________________________________________________________________
dense_67 (Dense)                (None, 128)          262272      flatten_27[0][0]                 
==================================================================================================
Total params: 23,843,712
Trainable params: 23,790,592
Non-trainable params: 53,120
__________________________________________________________________________________________________
Train on 108400 samples, validate on 17820 samples
Epoch 1/1
108400/108400 [==============================] - 137402s 1s/step - loss: 1.0951 - accuracy: 0.5001 - val_loss: 0.4997 - val_accuracy: 0.5000
* Accuracy on training set: 50.00%
* Accuracy on test set: 50.00%

In [ ]: